-
Notifications
You must be signed in to change notification settings - Fork 215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Implements RNNT+MMI #1030
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@danpovey Do you have any good idea to test this function, I can only think of constructing simple test cases.
k2/csrc/fsa_algo.cu
Outdated
repeat_num = us_row_splits1_data[us_idx0 + 1] - | ||
us_row_splits1_data[us_idx0]; | ||
|
||
arc.score = -logf(1 - powf(1 - sampling_prob, repeat_num)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I only include the "predictor" head output in C++ part, the other two scores (i.e. hybrid output and lm_output) will add on python part, it would be easier to enable autograd for hybrid output.
k2/python/k2/fsa_algo.py
Outdated
a_value = getattr(lattice, "scores") | ||
# Enable autograd for path_scores | ||
b_value = index_select(path_scores.flatten(), arc_map) | ||
value = a_value + b_value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
path_scores
here will contain hybrid_output and detached lm_output. I include the path_scores
here and enable antograd to path_scores
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, OK. Right, we treat those as differentiable, but the negated sampling_prob is treated as just a constant.
# index == 0 means the sampled symbol is blank | ||
t_mask = index == 0 | ||
# t_index = torch.where(t_mask, t_index + 1, t_index) | ||
t_index = t_index + 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we use regular RNN-T, it is possible to generate too many symbols for a specific frame, and that might be chances to generate a lattice containing cycles, which is not expected. I am not sure whether we will encounter such a issue at the very beginning of training.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm, a valid point. Yes, computing forward backward scores would not work correctly if there are cycles. One possibility would be to augment the state with a sub-frame, i.e. instead of (ctx, t) it becomes (ctx, t, sub_t) with sub_t = (0, 1, 2, ..). That would prevent cycles, although it might prevent a small number of paths from recombining that might otherwise be able to recombine.
It runs normally in my self-constructed test case, not fully tested yet, though.
The sampled paths:
The corresponding lattice:
Note: There is an arc from state 2 to state 17 in the second lattice, because the last symbol of the second path of second sequence is sampled at frame 1, it is a simulation of reaching final frame.